import numpy as np
from qutip import *
import matplotlib.pyplot as plt
from scipy.special import factorial, hermite
import math
import time
import os
tic = time.time()  # Wall-clock start time (time.time()); every elapsed-time message below is measured from here
# ============================================================================
# Monte Carlo simulation of a multimode Coherent Ising Machine (CIM)
# using quantum jump (MCWF) methods in QuTiP.
# Features:
# - Simulates coherent coupling between multiple DOPO modes
# - Supports vacuum, coherent-superposition and entangled-cat initial states
# - Supports time-dependent parameters
# - Coupling matrix J determines the Ising Problem
# - Computes success rate and time/sampling error estimates
# ============================================================================

# Define simulation parameters in a class
class Parameters:
    """Simulation settings for the multimode CIM run, read as ``p.<name>``.

    Scalar settings are plain class attributes. The two arrays that the script
    fills in place (``J`` and ``S``) are re-created per instance in ``__init__``,
    so separate instances never share (and silently co-mutate) one class-level
    buffer.
    """

    g = 0.6                              # Nonlinear (two-photon) dissipation coefficient
    alpha = np.sqrt(2.4 / (g ** 2))      # Coherent amplitude sqrt(lambda / g**2) with lambda = 2.4
    N = 15                               # Fock-space dimension per mode (photon numbers 0..N-1)
    M = 2                                # Number of modes (one Ising spin per mode)
    gamma = 1                            # Linear dissipation rate
    lam = g ** 2 * alpha ** 2            # Pump rate lambda (= 2.4 up to rounding)
    det = 0                              # Detuning
    t = 2                                # Total simulation time
    steps = 400                          # Number of time steps
    times = np.linspace(0, t, steps + 1) # Output time grid (steps + 1 points)
    traj = 500                           # Trajectories per mcsolve call
    runs = 10                            # Number of independent runs
    n_val = np.arange(0, N)              # Fock-state labels 0..N-1
    JJ = 1                               # Coupling sign: ferromagnetic (+1) / antiferromagnetic (-1)
    J = np.zeros([M, M])                 # Coupling matrix (class-level template; per-instance copy in __init__)
    S = np.zeros((2 ** M, M))            # Spin configuration matrix (class-level template; per-instance copy in __init__)
    index = 0                            # Initial state: 0 = vacuum, 1 = coherent superposition, 2 = entangled cat state
    timestep = 1                         # 1 = also run a 2x finer time grid to estimate the time-step error

    def __init__(self):
        # Per-instance arrays: the script writes into p.J[...] and p.S[:, m]
        # in place, which would otherwise modify the shared class attributes.
        self.J = np.zeros([self.M, self.M])
        self.S = np.zeros((2 ** self.M, self.M))

p = Parameters()
a = destroy(p.N)  # Single-mode annihilation operator on the N-dimensional Fock space

# Nearest-neighbour ring coupling: mode m couples symmetrically to mode (m + 1) mod M.
for m in range(p.M):
    # NOTE: the seed no longer feeds the (commented-out) random weights, but it
    # still fixes NumPy's global RNG state, which may matter if anything downstream
    # draws from it -- kept so runs stay reproducible as before.
    np.random.seed(m)  # Fixed seed
    nxt = (m + 1) % p.M
    p.J[m, nxt] = 1  # np.random.rand()
    p.J[nxt, m] = p.J[m, nxt]

# Apply the global sign last so that JJ = -1 flips every bond, including (0, 1).
# The ring above already sets the (0, 1) pair, so it needs no separate override
# (an unconditional J[0, 1] = 1 here would undo JJ and fail for M = 1).
p.J = p.JJ * p.J
# ============================================================================
# Time-dependent parameter functions
# ============================================================================

# Parametric (time-dependent) coefficient functions passed to QuTiP
def col_g(t, args):
    """Coefficient multiplying the two-photon loss operator a_m**2.

    (t, args) is the signature QuTiP uses for time-dependent coefficients;
    both are ignored and the constant p.g is returned (can be made dynamic).
    """
    nonlinear_coeff = p.g
    return nonlinear_coeff

def H_lam(t, args):
    """Pump coefficient (i/2) * g**2 * alpha**2 for the squeezing term a^dag**2 - a**2.

    With alpha = sqrt(lambda / g**2) this equals i * lambda / 2. Both arguments
    follow QuTiP's coefficient signature and are ignored.
    """
    half_i = 1j / 2
    return half_i * p.g ** 2 * p.alpha ** 2

def col_gamma(t, args):
    """Coefficient sqrt(2 * gamma) multiplying the linear-loss operator a_m.

    Arguments follow QuTiP's coefficient signature and are ignored.
    """
    loss_rate = 2 * p.gamma
    return np.sqrt(loss_rate)

def col_J(t, args):
    """Coefficient for every coupling collapse operator; constant 1 (can be made dynamic).

    Arguments follow QuTiP's coefficient signature and are ignored.
    """
    coupling_scale = 1
    return coupling_scale

# ============================================================================
# Operator construction
# ============================================================================

# Create annihilation operator for mode m in tensor form
def an(m):
    """Annihilation operator of mode ``m`` embedded in the M-mode tensor space.

    Mode ``m`` carries the single-mode operator ``a``; every other slot is the
    N-dimensional identity.
    """
    factors = []
    for slot in range(p.M):
        factors.append(a if slot == m else qeye(p.N))
    return tensor(factors)

# Hamiltonian: one time-dependent squeezing (pump) term per mode.
# Collapse operators: linear and two-photon loss per mode, then dissipative couplings.
H = []
col = []
for k in range(p.M):
    ak = an(k)
    H.append([ak.dag() * ak.dag() - ak * ak, H_lam])   # Squeezing term
    col.append([ak, col_gamma])                        # Linear loss
    col.append([ak * ak, col_g])                       # Nonlinear loss

# Dissipative coupling sqrt(|J_ij|) * (a_i - sign(J_ij) * a_j) for every nonzero entry.
# NOTE(review): J is built symmetric, so each unordered pair contributes two
# operators differing only by an overall sign -- the same dissipator twice,
# i.e. an effectively doubled coupling rate. Confirm this is intended.
for i in range(p.M):
    for j in range(p.M):
        Jij = p.J[i, j]
        if Jij == 0:
            continue
        sign = Jij / np.abs(Jij)
        col.append([np.sqrt(np.abs(Jij)) * (an(i) - sign * an(j)), col_J])

# ============================================================================
# Initial state preparation
# ============================================================================
# Candidate initial states, chosen by p.index:
#   0 -> vacuum in every mode
#   1 -> product of even cats |alpha> + |-alpha> in every mode, normalised as a whole
#   2 -> normalised sum over n of "cat in mode n, vacuum in all other modes"
vac = basis(p.N, 0)
cat = coherent(p.N, p.alpha, method='analytic') + coherent(p.N, -p.alpha, method='analytic')

psi = [tensor([vac for _ in range(p.M)])]   # Vac state
psi3 = tensor([cat for _ in range(p.M)])    # COH_Sup state (unnormalised)
psi.append(psi3 / psi3.norm())

psi2 = psi[0] * 0                            # COH_Ent accumulator (zero ket)
for n in range(p.M):
    psi2 = psi2 + tensor([cat if k == n else vac for k in range(p.M)])
psi.append(psi2 / psi2.norm())

psi0 = psi[p.index]  # Selected initial state

# ============================================================================
# Spin observables and Ising solution
# ============================================================================

def Hfun(M, N, S):
    """Half-line overlap of the Fock (Hermite-Gauss) wavefunctions M and N.

    Expands H_M(x) * H_N(x) * exp(-x**2) term by term and integrates each power
    of x over one half of the real line: S = +1 selects [0, inf), S = -1 selects
    (-inf, 0]. The two halves sum to the Kronecker delta. Only S = +1 / -1 is
    meaningful.
    """
    half_m = M // 2
    half_n = N // 2
    total = 0
    for i in range(half_m + 1):
        for j in range(half_n + 1):
            power = M + N - 2 * j - 2 * i
            total += S ** power * (-1) ** (i + j) / (
                factorial(i) * factorial(j) * factorial(M - 2 * i) * factorial(N - 2 * j)
            ) * 2 ** (-2 * j - 2 * i - 1) * math.gamma((power + 1) / 2)
    norm = np.sqrt(np.pi * 2 ** (M + N) * factorial(M) * factorial(N))
    return total * factorial(M) * factorial(N) * 2 ** (M + N) / norm

# Single-mode spin projectors in the Fock basis: Sup uses the S = +1 half-line
# overlaps and Sdwn the S = -1 ones (together they sum to the identity).
Sup = np.array([[Hfun(row, c, 1) for c in range(p.N)] for row in range(p.N)], dtype=float)
Sdwn = np.array([[Hfun(row, c, -1) for c in range(p.N)] for row in range(p.N)], dtype=float)

# Enumerate all 2**M spin configurations into p.S: column c alternates -1 / +1
# in runs of length 2**c.
for col_idx in range(p.M):
    run = 2 ** col_idx
    p.S[:, col_idx] = np.tile(np.repeat([-1, 1], run), 2 ** (p.M - col_idx - 1))

# Classical Ising energy E(s) = -s^T J s of every configuration, one row per row of p.S.
sol = np.zeros((len(p.S), 1))
for q, spins in enumerate(p.S):
    sol[q] = np.dot(-np.dot(spins, p.J), spins)

# Lift the single-mode projectors to the full M-mode space: SSup[n] / SSdwn[n]
# act as Sup / Sdwn on mode n and as the identity on every other mode.
SSup = []
SSdwn = []
for target in range(p.M):
    up_factors = [Qobj(Sup) if k == target else qeye(p.N) for k in range(p.M)]
    down_factors = [Qobj(Sdwn) if k == target else qeye(p.N) for k in range(p.M)]
    SSup.append(tensor(up_factors))
    SSdwn.append(tensor(down_factors))

# ============================================================================
# Expectation value functions
# ============================================================================

def exp(t, state):
    """Compute success rate based on projection onto lowest Ising configurations.

    For every spin configuration ``p.S[s]`` whose Ising energy ``sol[s]`` equals
    the minimum, apply the product of per-mode projectors (``SSup`` for spin +1,
    ``SSdwn`` otherwise) and accumulate <state|P_s|state>.

    Used as a callable expectation operator for ``mcsolve``: ``t`` is unused.
    ``state`` is wrapped in ``Qobj`` for the bra side, so it is presumably the
    raw state-vector data QuTiP hands to callable e_ops -- TODO confirm for the
    installed QuTiP version.

    Returns the accumulated overlap (complex in general; callers take ``.real``).
    """
    ground = sol.min()          # Loop-invariant: evaluated once instead of per configuration
    bra = Qobj(state).dag()     # Loop-invariant bra <state|
    SR = 0
    for s in range(len(p.S)):   # Loop over spin configurations
        if sol[s] != ground:    # Only ground-state configurations contribute
            continue
        projected = state
        for n in range(p.M):    # Apply one projector per mode (they act on different modes)
            mode = p.M - n - 1
            if p.S[s, mode] == 1:
                projected = SSup[mode] * projected
            else:
                projected = SSdwn[mode] * projected
        SR = SR + bra * projected
    return SR

def Nn(m):
    """Photon number operator a^dag a of mode ``m`` in the M-mode tensor space."""
    number_op = a.dag() * a
    return tensor([number_op if k == m else qeye(p.N) for k in range(p.M)])

def Nexp(t, state):
    """Largest mean photon number <n_m> over all modes for ``state`` (``t`` unused)."""
    bra = Qobj(state).dag()
    per_mode = [bra * Nn(k) * state for k in range(p.M)]
    return max(per_mode)

# ============================================================================
# Run MCWF simulations with time/sampling error analysis
# ============================================================================

# Solver options: states are not stored, only the callable expectation values.
options = Options()
options.store_final_state = False
options.store_states = False    # assign 'True' here to save Raw trajectories

print('Starting QuTiP simulation. Elapsed time: {:.2f} s'.format(time.time() - tic))
SRate = np.zeros([p.runs, len(p.times)])  # Success rate: one row per run, one column per output time

if p.timestep == 1:
    # Second output grid with twice as many steps, used to estimate the time-step error.
    f_steps = p.steps * 2  # Fine grain number of steps
    f_times = np.linspace(0, p.t, f_steps + 1)
    f_SRate = np.zeros([p.runs, len(f_times)])

for r in range(p.runs):
    # Empty seed list at the start of each run; presumably QuTiP then draws fresh
    # trajectory seeds -- TODO confirm for the installed QuTiP version.
    options.seeds = []  # Random seeds for MC trajectories

    if p.timestep == 1: # Fine-grid run, only when the time-step error is requested
        f_psi = mcsolve(H, psi0, f_times, col, [exp], ntraj=p.traj, options=options, progress_bar=False)
        # Hand the fine run's seeds to the coarse run below, intended so both runs
        # sample the same random trajectories and differ only in the time grid.
        options.seeds = f_psi.seeds # Store the seeds for rng
        f_SRate[r, :] = f_psi.expect[0].real

    psi_1 = mcsolve(H, psi0, p.times, col, [exp], ntraj=p.traj, options=options, progress_bar=False)
    SRate[r, :] = psi_1.expect[0].real
    # Print Progress
    print(f"{(r + 1) / p.runs * 100:.1f}% complete, elapsed time: {time.time() - tic:.2f} s")


# ============================================================================
# Post-processing and result output
# ============================================================================

toc = time.time() - tic  # Total elapsed time since `tic` (time.time() is wall-clock time)
print(f"CPU time: {toc:.2f} s")  # Printed label kept as-is; the value is wall-clock time
SS = SRate.mean(axis=0)  # Success rate averaged over the independent runs
SS_peak = max(SS)        # Peak mean success rate; normalises both error estimates

if p.timestep == 1:
    # Compare the coarse and fine runs at the SAME time instants. The fine grid splits
    # every coarse step into `stride` sub-steps, so coarse sample k sits at fine index
    # k * stride (k * 2 here). Sample k = 0 (t = 0) is skipped: both runs start there
    # from the same initial state.
    stride = (f_SRate.shape[1] - 1) // (len(SS) - 1)
    diff = np.abs(SS[1:] - f_SRate.mean(axis=0)[stride::stride])
    # Display the RMS time-step error, normalised by the peak success rate
    print("Time Step Error:", np.sqrt(np.mean(diff ** 2)) / SS_peak)

# Estimated standard error of the mean success rate at each time point (over p.runs runs)
sampl_err = SRate.std(axis=0) / np.sqrt(p.runs)
# Display the RMS sampling error, normalised by the peak success rate
print("Sampling Error:", np.sqrt(np.mean(sampl_err ** 2)) / SS_peak)
# Plot Success Rate
plt.plot(p.times, SS)
plt.legend(["Success Rate"], loc="lower right")
#np.savetxt('SR_array.txt', SS) # Uncomment if  SRate array needs saving
print("Max Success Rate:", SS_peak)
print("Time of Max Success:", p.times[np.argmax(SS)])
print(f'M={p.M}, N={p.N}, alpha={round(p.alpha, 2)}, g={p.g}, lam={p.lam}, steps={p.steps}, gamma={p.gamma}')
plt.show()